import os
from collections import Counter

import numpy as np

def _write_ids(path, indices):
	"""Write one line index per line to *path*."""
	with open(path, 'w', encoding='utf8') as fw:
		fw.writelines(str(idx) + '\n' for idx in indices)


def _write_split(path, indices, labels, cor):
	"""Write ``label<TAB>text`` rows for *indices*; report empty texts on stdout."""
	with open(path, 'w', encoding='utf8', newline='\n') as fw:
		for idx in indices:
			text = cor[idx]
			if text:
				fw.write(str(labels[idx]) + '\t' + text + '\n')
			else:
				# Documents that are empty after stripping are skipped, not written.
				print('empty line:', idx)


def buil_domain_svm_dataset(datadir, domain_name):
	"""Split one domain's corpus into train/val/test files, stratified by label.

	Reads ``<datadir>/Cor/cor_<domain_name>.txt`` (one document per line) and
	``<datadir>/Label/label_<domain_name>.txt`` (one label per line, matched to
	the corpus by line number) and writes into
	``<datadir>/forTC/cor_<domain_name>/``:

	* ``label2id.txt`` -- ``label<TAB>id`` rows, ids assigned in order of
	  first appearance;
	* ``train.txt`` / ``val.txt`` / ``test.txt`` -- ``label<TAB>text`` rows;
	* ``train_cid.txt`` / ``val_cid.txt`` / ``test_cid.txt`` -- the 0-based
	  source line index of every sample assigned to that split.

	Label lines that are blank after stripping are left out of every split.
	Within each label the line indices are shuffled with a fixed seed and then
	cut into test (2/3), train (7/8 of the remaining third) and val (the rest).

	Args:
		datadir: Root directory containing the ``Cor`` and ``Label`` folders.
		domain_name: Suffix used in the input file names and output folder.

	Raises:
		OSError: If an input file cannot be read or an output cannot be written.
		IndexError: If a labelled line index has no matching corpus line.

	Note:
		Re-seeds NumPy's global random generator as a side effect.
	"""
	cor_path = os.path.join(datadir, 'Cor/cor_' + domain_name + '.txt')
	label_path = os.path.join(datadir, 'Label/label_' + domain_name + '.txt')

	savedir = os.path.join(datadir, 'forTC/cor_' + domain_name)
	# makedirs also creates a missing 'forTC' parent, which os.mkdir would not.
	os.makedirs(savedir, exist_ok=True)

	# Labels are keyed by source line index so that blank label lines cannot
	# shift the label <-> document alignment.
	labels = {}
	label2id = {}
	labels_idx = []  # labels_idx[label id] -> line indices carrying that label
	with open(label_path, 'r', encoding='utf8') as fr:
		for i, line in enumerate(fr):
			label = line.strip()
			if not label:
				continue
			labels[i] = label
			if label not in label2id:
				label2id[label] = len(label2id)
				labels_idx.append([])
			labels_idx[label2id[label]].append(i)

	# save label2id
	with open(os.path.join(savedir, 'label2id.txt'), 'w', encoding='utf8', newline='\n') as fw:
		for label, label_id in label2id.items():
			fw.write(str(label) + '\t' + str(label_id) + '\n')

	# Shuffle inside each label group; the fixed seed keeps splits reproducible.
	np.random.seed(170311)
	for group in labels_idx:
		np.random.shuffle(group)

	# Segment every label group into test / train / val.
	test_ratio = 2/3
	train_ratio = (7/8) * (1/3)

	train_idx = []
	val_idx = []
	test_idx = []
	for group in labels_idx:
		total_cnt = len(group)
		test_cnt = int(test_ratio * total_cnt)
		train_cnt = int(train_ratio * total_cnt)

		test_idx.extend(group[:test_cnt])
		train_idx.extend(group[test_cnt:(test_cnt + train_cnt)])
		# Whatever is left after the truncating int() casts goes to val.
		val_idx.extend(group[(test_cnt + train_cnt):])

	_write_ids(os.path.join(savedir, 'test_cid.txt'), test_idx)
	_write_ids(os.path.join(savedir, 'train_cid.txt'), train_idx)
	_write_ids(os.path.join(savedir, 'val_cid.txt'), val_idx)

	# load cor
	with open(cor_path, 'r', encoding='utf8') as fr:
		cor = [line.strip() for line in fr]

	_write_split(os.path.join(savedir, 'train.txt'), train_idx, labels, cor)
	_write_split(os.path.join(savedir, 'val.txt'), val_idx, labels, cor)
	_write_split(os.path.join(savedir, 'test.txt'), test_idx, labels, cor)


# Run the dataset builder once for each domain name '0' through '8'.
datadir = './data_correct/Ama'
domain_name_list = list(map(str, range(9)))
for domain_name in domain_name_list:
	buil_domain_svm_dataset(datadir=datadir, domain_name=domain_name)
